from continual_rl.experiments.experiment import Experiment
from continual_rl.experiments.tasks.make_atari_task import get_single_atari_task
from continual_rl.experiments.tasks.make_minihack_task import get_single_minihack_task
from continual_rl.available_policies import LazyDict


def create_atari_sequence_loader(
    task_prefix,
    game_names,
    num_timesteps=5e7,
    max_episode_steps=None,
    full_action_space=False,
    continual_testing_freq=1000,
    cycle_count=1,
):
    """Return a deferred loader for an experiment that plays Atari games in sequence.

    Each game in ``game_names`` becomes one task. Its position in the list is used
    both as its action-space id and as the suffix of its task id
    (``f"{task_prefix}_{index}"``).

    Args:
        task_prefix: Prefix shared by every task id in the sequence.
        game_names: Environment ids of the games, in the order they are played.
        num_timesteps: Timestep budget passed to every task.
        max_episode_steps: Forwarded unchanged to each task.
        full_action_space: Forwarded unchanged to each task.
        continual_testing_freq: Forwarded unchanged to ``Experiment``.
        cycle_count: Forwarded unchanged to ``Experiment``.

    Returns:
        A zero-argument callable. It builds the tasks and the ``Experiment`` only
        when it is invoked.
    """
    def build_task(space_id, game):
        # One task per game; the list index doubles as the action-space id.
        return get_single_atari_task(
            f"{task_prefix}_{space_id}",
            space_id,
            game,
            num_timesteps,
            max_episode_steps=max_episode_steps,
            full_action_space=full_action_space,
        )

    def loader():
        sequence = [build_task(idx, game) for idx, game in enumerate(game_names)]
        return Experiment(
            sequence,
            continual_testing_freq=continual_testing_freq,
            cycle_count=cycle_count,
        )

    return loader


def create_atari_single_game_loader(env_name, num_timesteps=5e7, max_episode_steps=10000):
    """Return a deferred loader for an experiment consisting of a single Atari game.

    Args:
        env_name: Environment id of the game. It is also used as the task id, because
            a "single game" experiment has a 1:1 mapping between environment and task.
        num_timesteps: Timestep budget for the task. Defaults to 5e7, the value
            previously hard-coded here.
        max_episode_steps: Forwarded to the task. Defaults to 10000, the value
            previously hard-coded here.

    Returns:
        A zero-argument callable. It builds the task and the ``Experiment`` only
        when it is invoked.
    """
    def loader():
        task = get_single_atari_task(
            env_name,
            0,  # the only task, so it uses action-space id 0
            env_name,
            num_timesteps=num_timesteps,
            max_episode_steps=max_episode_steps,
        )
        return Experiment(tasks=[task])

    return loader





def create_minihack_loader(
    task_prefix,
    env_name_pairs,
    num_timesteps=10e6,
    task_params=None,
    continual_testing_freq=1000,
    cycle_count=1,
):
    """Return a deferred loader for a MiniHack experiment built from (train, eval) pairs.

    For each pair at index ``i``, two tasks are created and share action-space id
    ``i``:
      * a training task ``f"{task_prefix}_{i}"`` on ``pair[0]`` with
        ``num_timesteps`` timesteps;
      * an evaluation task ``f"{task_prefix}_{i}_eval"`` on ``pair[1]`` with 0
        timesteps and ``eval_mode=True``.

    Args:
        task_prefix: Prefix shared by every task id.
        env_name_pairs: Sequence of (training env name, evaluation env name) pairs.
        num_timesteps: Timestep budget for each training task. Also forwarded to
            ``Experiment``.
        task_params: Extra keyword arguments applied to every task. ``None`` means
            none.
        continual_testing_freq: Forwarded unchanged to ``Experiment``.
        cycle_count: Forwarded unchanged to ``Experiment``.

    Returns:
        A zero-argument callable. It builds the tasks and the ``Experiment`` only
        when it is invoked.
    """
    if task_params is None:
        task_params = {}

    def loader():
        task_list = []
        for space_id, pair in enumerate(env_name_pairs):
            base_id = f"{task_prefix}_{space_id}"
            task_list.append(
                get_single_minihack_task(base_id, space_id, pair[0], num_timesteps, **task_params)
            )
            task_list.append(
                get_single_minihack_task(
                    f"{base_id}_eval", space_id, pair[1], 0, eval_mode=True, **task_params
                )
            )

        return Experiment(
            task_list,
            continual_testing_freq=continual_testing_freq,
            cycle_count=cycle_count,
            # NOTE(review): the Atari loaders do not pass num_timesteps to Experiment.
            # Confirm that Experiment's constructor accepts this keyword.
            num_timesteps=num_timesteps,
        )

    return loader


def get_available_experiments():
    """Return the registry mapping experiment names to zero-argument loaders.

    The registry is wrapped in ``LazyDict`` (from ``continual_rl.available_policies``).
    NOTE(review): presumably LazyDict defers invoking a loader until its key is
    looked up. Confirm this against LazyDict.
    """
    # ============ Atari ============
    atari_six_games = [
        "SpaceInvadersNoFrameskip-v4",
        "KrullNoFrameskip-v4",
        "BeamRiderNoFrameskip-v4",
        "HeroNoFrameskip-v4",
        "StarGunnerNoFrameskip-v4",
        "MsPacmanNoFrameskip-v4",
    ]
    # NOTE(review): the name says "5_cycles" but cycle_count is 2. Confirm which
    # one is intended.
    atari_six_loader = create_atari_sequence_loader(
        "atari_6_tasks_5_cycles",
        atari_six_games,
        max_episode_steps=10000,
        num_timesteps=1e7,
        full_action_space=True,
        continual_testing_freq=1e6,
        cycle_count=2,
    )

    # ============ MiniHack =========
    # Each entry is (training env, evaluation env).
    minihack_nav_pairs = [
        ("Room-Random-5x5-v0", "Room-Random-15x15-v0"),
        ("Corridor-R2-v0", "Corridor-R5-v0"),
        ("Room-Dark-5x5-v0", "Room-Dark-15x15-v0"),
        ("Corridor-R3-v0", "Corridor-R5-v0"),
        ("Room-Monster-5x5-v0", "Room-Monster-15x15-v0"),
        ("CorridorBattle-v0", "CorridorBattle-Dark-v0"),
        ("Room-Trap-5x5-v0", "Room-Trap-15x15-v0"),
        ("HideNSeek-v0", "HideNSeek-Big-v0"),
        ("Room-Ultimate-5x5-v0", "Room-Ultimate-15x15-v0"),
        ("HideNSeek-Lava-v0", "HideNSeek-Big-v0"),
    ]
    minihack_nav_loader = create_minihack_loader(
        "minihack_nav_paired_2_cycles",
        minihack_nav_pairs,
        # NOTE(review): these were previously 10e6 timesteps and a 1e6 testing freq.
        # Confirm the reduced values below are not leftover debug settings.
        num_timesteps=1e6,
        continual_testing_freq=1e5,
        cycle_count=2,
    )

    return LazyDict({
        "atari_6_tasks_5_cycles": atari_six_loader,
        "minihack_nav_paired_2_cycles": minihack_nav_loader,
    })
